{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0c6cea67-3c4e-4798-ba5f-9d7191621dbb",
   "metadata": {
    "id": "0c6cea67-3c4e-4798-ba5f-9d7191621dbb"
   },
   "source": [
    "# SciBERT for Multi-Label Classification\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/center-for-threat-informed-defense/tram/blob/main/user_notebooks/fine_tune_multi_label.ipynb)\n",
    "\n",
    "This notebook allows one to continue fine-tuning our provided SciBERT-for-multilabel-sequence-classification model on custom data.\n",
    "\n",
    "To start, first select `Runtime > Change runtime type`, and under `Hardware accelerator` select `GPU`. Then run the next two cells. The first cell will download the model and install the Python dependencies. The second cell will set up the label encoder that defines the ATT&CK technique classes the model predicts."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719a7mZxyi1w",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 24075,
     "status": "ok",
     "timestamp": 1689378569589,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "719a7mZxyi1w",
    "outputId": "754d8cf0-b44f-4e51-865b-90055191289f"
   },
   "outputs": [],
   "source": [
    "# Download the fine-tuned model's config and weights into a local directory.\n",
    "# `-p` keeps this cell re-runnable when the directory already exists.\n",
    "!mkdir -p scibert_multi_label_model\n",
    "!wget https://ctidtram.blob.core.windows.net/tram-models/multi-label-20230803/config.json -O scibert_multi_label_model/config.json\n",
    "!wget https://ctidtram.blob.core.windows.net/tram-models/multi-label-20230803/pytorch_model.bin -O scibert_multi_label_model/pytorch_model.bin\n",
    "# `%pip` installs into the running kernel's environment; scikit-learn and tqdm are imported by later cells.\n",
    "%pip install torch transformers pandas scikit-learn tqdm"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f55728b-50a9-4073-9570-29a58b8f972d",
   "metadata": {
    "id": "9f55728b-50a9-4073-9570-29a58b8f972d"
   },
   "source": [
    "This cell instantiates the label encoder. Do not modify this cell, as the classes (i.e., ATT&CK techniques) and their order must match those the model expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "75c82006-f3b4-409a-9bde-035eb14070b5",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 179
    },
    "executionInfo": {
     "elapsed": 1024,
     "status": "ok",
     "timestamp": 1689378570611,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "75c82006-f3b4-409a-9bde-035eb14070b5",
    "outputId": "b7a8b194-ce1e-4360-bed1-45a30acdcb34",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {color: black;background-color: white;}#sk-container-id-1 pre{padding: 0;}#sk-container-id-1 div.sk-toggleable {background-color: white;}#sk-container-id-1 label.sk-toggleable__label {cursor: pointer;display: block;width: 100%;margin-bottom: 0;padding: 0.3em;box-sizing: border-box;text-align: center;}#sk-container-id-1 label.sk-toggleable__label-arrow:before {content: \"▸\";float: left;margin-right: 0.25em;color: #696969;}#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {color: black;}#sk-container-id-1 div.sk-estimator:hover label.sk-toggleable__label-arrow:before {color: black;}#sk-container-id-1 div.sk-toggleable__content {max-height: 0;max-width: 0;overflow: hidden;text-align: left;background-color: #f0f8ff;}#sk-container-id-1 div.sk-toggleable__content pre {margin: 0.2em;color: black;border-radius: 0.25em;background-color: #f0f8ff;}#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {max-height: 200px;max-width: 100%;overflow: auto;}#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {content: \"▾\";}#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 input.sk-hidden--visually {border: 0;clip: rect(1px 1px 1px 1px);clip: rect(1px, 1px, 1px, 1px);height: 1px;margin: -1px;overflow: hidden;padding: 0;position: absolute;width: 1px;}#sk-container-id-1 div.sk-estimator {font-family: monospace;background-color: #f0f8ff;border: 1px dotted black;border-radius: 0.25em;box-sizing: border-box;margin-bottom: 0.5em;}#sk-container-id-1 div.sk-estimator:hover {background-color: #d4ebff;}#sk-container-id-1 div.sk-parallel-item::after {content: \"\";width: 100%;border-bottom: 1px solid gray;flex-grow: 1;}#sk-container-id-1 
div.sk-label:hover label.sk-toggleable__label {background-color: #d4ebff;}#sk-container-id-1 div.sk-serial::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: 0;}#sk-container-id-1 div.sk-serial {display: flex;flex-direction: column;align-items: center;background-color: white;padding-right: 0.2em;padding-left: 0.2em;position: relative;}#sk-container-id-1 div.sk-item {position: relative;z-index: 1;}#sk-container-id-1 div.sk-parallel {display: flex;align-items: stretch;justify-content: center;background-color: white;position: relative;}#sk-container-id-1 div.sk-item::before, #sk-container-id-1 div.sk-parallel-item::before {content: \"\";position: absolute;border-left: 1px solid gray;box-sizing: border-box;top: 0;bottom: 0;left: 50%;z-index: -1;}#sk-container-id-1 div.sk-parallel-item {display: flex;flex-direction: column;z-index: 1;position: relative;background-color: white;}#sk-container-id-1 div.sk-parallel-item:first-child::after {align-self: flex-end;width: 50%;}#sk-container-id-1 div.sk-parallel-item:last-child::after {align-self: flex-start;width: 50%;}#sk-container-id-1 div.sk-parallel-item:only-child::after {width: 0;}#sk-container-id-1 div.sk-dashed-wrapped {border: 1px dashed gray;margin: 0 0.4em 0.5em 0.4em;box-sizing: border-box;padding-bottom: 0.4em;background-color: white;}#sk-container-id-1 div.sk-label label {font-family: monospace;font-weight: bold;display: inline-block;line-height: 1.2em;}#sk-container-id-1 div.sk-label-container {text-align: center;}#sk-container-id-1 div.sk-container {/* jupyter's `normalize.less` sets `[hidden] { display: none; }` but bootstrap.min.css set `[hidden] { display: none !important; }` so we also need the `!important` here to be able to override the default hidden behavior on the sphinx rendered scikit-learn.org. 
See: https://github.com/scikit-learn/scikit-learn/issues/21755 */display: inline-block !important;position: relative;}#sk-container-id-1 div.sk-text-repr-fallback {display: none;}</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>MultiLabelBinarizer(classes=[&#x27;T1003.001&#x27;, &#x27;T1005&#x27;, &#x27;T1012&#x27;, &#x27;T1016&#x27;,\n",
       "                             &#x27;T1021.001&#x27;, &#x27;T1027&#x27;, &#x27;T1033&#x27;, &#x27;T1036.005&#x27;,\n",
       "                             &#x27;T1041&#x27;, &#x27;T1047&#x27;, &#x27;T1053.005&#x27;, &#x27;T1055&#x27;,\n",
       "                             &#x27;T1056.001&#x27;, &#x27;T1057&#x27;, &#x27;T1059.003&#x27;, &#x27;T1068&#x27;,\n",
       "                             &#x27;T1070.004&#x27;, &#x27;T1071.001&#x27;, &#x27;T1072&#x27;, &#x27;T1074.001&#x27;,\n",
       "                             &#x27;T1078&#x27;, &#x27;T1082&#x27;, &#x27;T1083&#x27;, &#x27;T1090&#x27;, &#x27;T1095&#x27;,\n",
       "                             &#x27;T1105&#x27;, &#x27;T1106&#x27;, &#x27;T1110&#x27;, &#x27;T1112&#x27;, &#x27;T1113&#x27;, ...])</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label sk-toggleable__label-arrow\">MultiLabelBinarizer</label><div class=\"sk-toggleable__content\"><pre>MultiLabelBinarizer(classes=[&#x27;T1003.001&#x27;, &#x27;T1005&#x27;, &#x27;T1012&#x27;, &#x27;T1016&#x27;,\n",
       "                             &#x27;T1021.001&#x27;, &#x27;T1027&#x27;, &#x27;T1033&#x27;, &#x27;T1036.005&#x27;,\n",
       "                             &#x27;T1041&#x27;, &#x27;T1047&#x27;, &#x27;T1053.005&#x27;, &#x27;T1055&#x27;,\n",
       "                             &#x27;T1056.001&#x27;, &#x27;T1057&#x27;, &#x27;T1059.003&#x27;, &#x27;T1068&#x27;,\n",
       "                             &#x27;T1070.004&#x27;, &#x27;T1071.001&#x27;, &#x27;T1072&#x27;, &#x27;T1074.001&#x27;,\n",
       "                             &#x27;T1078&#x27;, &#x27;T1082&#x27;, &#x27;T1083&#x27;, &#x27;T1090&#x27;, &#x27;T1095&#x27;,\n",
       "                             &#x27;T1105&#x27;, &#x27;T1106&#x27;, &#x27;T1110&#x27;, &#x27;T1112&#x27;, &#x27;T1113&#x27;, ...])</pre></div></div></div></div></div>"
      ],
      "text/plain": [
       "MultiLabelBinarizer(classes=['T1003.001', 'T1005', 'T1012', 'T1016',\n",
       "                             'T1021.001', 'T1027', 'T1033', 'T1036.005',\n",
       "                             'T1041', 'T1047', 'T1053.005', 'T1055',\n",
       "                             'T1056.001', 'T1057', 'T1059.003', 'T1068',\n",
       "                             'T1070.004', 'T1071.001', 'T1072', 'T1074.001',\n",
       "                             'T1078', 'T1082', 'T1083', 'T1090', 'T1095',\n",
       "                             'T1105', 'T1106', 'T1110', 'T1112', 'T1113', ...])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer as MLB\n",
    "\n",
    "# ATT&CK technique IDs used as the label classes. Do not reorder or edit:\n",
    "# the classes and their order must match those the model expects.\n",
    "CLASSES = [\n",
    "   'T1003.001', 'T1005', 'T1012', 'T1016', 'T1021.001', 'T1027',\n",
    "   'T1033', 'T1036.005', 'T1041', 'T1047', 'T1053.005', 'T1055',\n",
    "   'T1056.001', 'T1057', 'T1059.003', 'T1068', 'T1070.004',\n",
    "   'T1071.001', 'T1072', 'T1074.001', 'T1078', 'T1082', 'T1083',\n",
    "   'T1090', 'T1095', 'T1105', 'T1106', 'T1110', 'T1112', 'T1113',\n",
    "   'T1140', 'T1190', 'T1204.002', 'T1210', 'T1218.011', 'T1219',\n",
    "   'T1484.001', 'T1518.001', 'T1543.003', 'T1547.001', 'T1548.002',\n",
    "   'T1552.001', 'T1557.001', 'T1562.001', 'T1564.001', 'T1566.001',\n",
    "   'T1569.002', 'T1570', 'T1573.001', 'T1574.002'\n",
    "]\n",
    "\n",
    "# Build the binarizer with CLASSES as its explicit class list, then fit it on\n",
    "# one single-label row per class.\n",
    "mlb = MLB(classes=CLASSES)\n",
    "mlb.fit([[c] for c in CLASSES])\n",
    "\n",
    "# Bare expression so the fitted binarizer is shown as the cell output.\n",
    "mlb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "296b611b-67ae-4201-a559-2f5f56ab2723",
   "metadata": {
    "id": "296b611b-67ae-4201-a559-2f5f56ab2723"
   },
   "source": [
    "This cell is for loading the training data. You will need to modify this cell to load your data. Ensure that by the end of this cell, a DataFrame has been assigned to the variable `data` that has a `sentence` column containing the sentences, and a `labels` column containing lists (or other container types) of strings, where those strings are the ATT&CK IDs that this model can classify. The lists can be empty for negative examples. It does not matter how the DataFrame is indexed, or whether it has additional columns with other names.\n",
    "\n",
    "For demonstration purposes, we will use the same multi-label data that was produced during this TRAM effort, even though the model was trained on this data already. This cell is only present to show the expected format of the `data` DataFrame, and is not intended to be run as shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2622429e-512d-443e-895b-eeec53afcecc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 874,
     "status": "ok",
     "timestamp": 1689378573840,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "2622429e-512d-443e-895b-eeec53afcecc",
    "outputId": "78a17cb4-12f0-4a2f-8872-ff179da8ec50",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "  <div id=\"df-1b9d5166-0a44-40b3-808b-cfba2b882508\">\n",
       "    <div class=\"colab-df-container\">\n",
       "      <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>sentence</th>\n",
       "      <th>labels</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>title: NotPetya Technical Analysis – A Triple ...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Executive Summary This technical analysis prov...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>For more information on CrowdStrike’s proactiv...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>NotPetya combines ransomware with the ability ...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>It spreads to Microsoft Windows machines using...</td>\n",
       "      <td>[T1210]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>495</th>\n",
       "      <td>The following is a list of services, hardcoded...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>496</th>\n",
       "      <td>QuickBooks.</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>497</th>\n",
       "      <td>FCSmemtasmepocsPDVFSServiceQBCFMonitorServiceQ...</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>498</th>\n",
       "      <td>#recycle$Recycle.</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>499</th>\n",
       "      <td>Bin1_config.iniAhnlabAll UsersAppDataAUTOEXEC....</td>\n",
       "      <td>[]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>500 rows × 2 columns</p>\n",
       "</div>\n",
       "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-1b9d5166-0a44-40b3-808b-cfba2b882508')\"\n",
       "              title=\"Convert this dataframe to an interactive table.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "       width=\"24px\">\n",
       "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
       "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
       "  </svg>\n",
       "      </button>\n",
       "\n",
       "\n",
       "\n",
       "    <div id=\"df-31eb2f7b-1cc8-451d-9cb0-15c5ad76c4b2\">\n",
       "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-31eb2f7b-1cc8-451d-9cb0-15c5ad76c4b2')\"\n",
       "              title=\"Suggest charts.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "      </button>\n",
       "    </div>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "    background-color: #E8F0FE;\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: #1967D2;\n",
       "    height: 32px;\n",
       "    padding: 0 0 0 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: #E2EBFA;\n",
       "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: #174EA6;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "    background-color: #3B4455;\n",
       "    fill: #D2E3FC;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart:hover {\n",
       "    background-color: #434B5C;\n",
       "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "    fill: #FFFFFF;\n",
       "  }\n",
       "</style>\n",
       "\n",
       "    <script>\n",
       "      async function quickchart(key) {\n",
       "        const containerElement = document.querySelector('#' + key);\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      }\n",
       "    </script>\n",
       "\n",
       "      <script>\n",
       "\n",
       "function displayQuickchartButton(domScope) {\n",
       "  let quickchartButtonEl =\n",
       "    domScope.querySelector('#df-31eb2f7b-1cc8-451d-9cb0-15c5ad76c4b2 button.colab-df-quickchart');\n",
       "  quickchartButtonEl.style.display =\n",
       "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "}\n",
       "\n",
       "        displayQuickchartButton(document);\n",
       "      </script>\n",
       "      <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      flex-wrap:wrap;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "      <script>\n",
       "        const buttonEl =\n",
       "          document.querySelector('#df-1b9d5166-0a44-40b3-808b-cfba2b882508 button.colab-df-convert');\n",
       "        buttonEl.style.display =\n",
       "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "        async function convertToInteractive(key) {\n",
       "          const element = document.querySelector('#df-1b9d5166-0a44-40b3-808b-cfba2b882508');\n",
       "          const dataTable =\n",
       "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                     [key], {});\n",
       "          if (!dataTable) return;\n",
       "\n",
       "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "            + ' to learn more about interactive tables.';\n",
       "          element.innerHTML = '';\n",
       "          dataTable['output_type'] = 'display_data';\n",
       "          await google.colab.output.renderOutput(dataTable, element);\n",
       "          const docLink = document.createElement('div');\n",
       "          docLink.innerHTML = docLinkHtml;\n",
       "          element.appendChild(docLink);\n",
       "        }\n",
       "      </script>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                                              sentence   labels\n",
       "0    title: NotPetya Technical Analysis – A Triple ...       []\n",
       "1    Executive Summary This technical analysis prov...       []\n",
       "2    For more information on CrowdStrike’s proactiv...       []\n",
       "3    NotPetya combines ransomware with the ability ...       []\n",
       "4    It spreads to Microsoft Windows machines using...  [T1210]\n",
       "..                                                 ...      ...\n",
       "495  The following is a list of services, hardcoded...       []\n",
       "496                                        QuickBooks.       []\n",
       "497  FCSmemtasmepocsPDVFSServiceQBCFMonitorServiceQ...       []\n",
       "498                                  #recycle$Recycle.       []\n",
       "499  Bin1_config.iniAhnlabAll UsersAppDataAUTOEXEC....       []\n",
       "\n",
       "[500 rows x 2 columns]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import pandas as pd\n",
    "data = pd.read_json('multi_label.json').drop(columns='doc_title').head(500)\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7136c953-fb1a-4e8a-887f-1ea4ec92add4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 81,
     "referenced_widgets": [
      "66eb7fdf62344f2fb31bf787d7a43d39",
      "be7219f5baa0404c957c33f57b90af43",
      "11fe0689ac4a4dca9cf2f88793bdceb6",
      "a031373b5ae04224996c0d377603028a",
      "fbb2fdd3537e4ec8a2d1c7304414f934",
      "f1ed4a086a9d46deac52a934c65bd46f",
      "fbafe2556e6e4b97840c34655b826810",
      "44e810873d234087a8de6598e82b62c8",
      "a8a3deb0863a456ab9303dc4d4fdb63d",
      "5f89cc094cac4f77b3058819bce39af3",
      "97250b845e38433c94c98a040c58fcc2",
      "29e5a977ca2e4c95bd106594066d365b",
      "8146378b52074fc6bd8042aa07c855ac",
      "3bbfc7ca557d438193cd246d85e93ceb",
      "dd9b9569b2394e0fb563838c553961a0",
      "fbdb96e08dc94bfc98bb3da709521bb1",
      "d07290d0e08a4923b7175e000576a767",
      "bdd93e271bbd45fdae397e9918a8a8b7",
      "93c6d77bfe614b258242919b33c4a6f4",
      "912dfc4ef2b94a58b924e21dbe7222df",
      "04b96cb68fb6460294d0912c6dc2c4af",
      "d197eece69da4eb29203ce2498369c82"
     ]
    },
    "executionInfo": {
     "elapsed": 15061,
     "status": "ok",
     "timestamp": 1689378592260,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "7136c953-fb1a-4e8a-887f-1ea4ec92add4",
    "outputId": "6278bd9c-c5cc-4365-b48d-35e4e3868e2b",
    "tags": []
   },
   "outputs": [],
   "source": [
    "import transformers\n",
    "import torch\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU so the notebook still runs (slowly).\n",
    "# The variable keeps the name `cuda` because other cells move tensors to it.\n",
    "cuda = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# The tokenizer comes from the base SciBERT checkpoint; the model weights come from the local\n",
    "# `scibert_multi_label_model` directory populated by the download cell.\n",
    "tokenizer = transformers.BertTokenizer.from_pretrained(\"allenai/scibert_scivocab_uncased\", max_length=512)\n",
    "# Move the model to the selected device and put it in training mode for fine-tuning.\n",
    "bert = transformers.BertForSequenceClassification.from_pretrained('scibert_multi_label_model').to(cuda).train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e0e7574e-60c1-485e-96b1-a161babc69ef",
   "metadata": {
    "executionInfo": {
     "elapsed": 226,
     "status": "ok",
     "timestamp": 1689378596099,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "e0e7574e-60c1-485e-96b1-a161babc69ef",
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hold out 20% of the rows for evaluation. The fixed `random_state` makes the shuffle\n",
    "# (and therefore the reported metrics) reproducible across kernel restarts.\n",
    "train, test = train_test_split(data, test_size=0.2, shuffle=True, random_state=42)\n",
    "\n",
    "def _load_data(x, y, batch_size=10):\n",
    "    \"\"\"Yield consecutive `(x, y)` mini-batches, each moved to the `cuda` device.\n",
    "\n",
    "    :x: inputs, sliced along the first dimension\n",
    "    :y: targets, sliced along the first dimension; must have as many rows as `x`\n",
    "    :batch_size: maximum number of rows per yielded batch (the last batch may be smaller)\n",
    "    \"\"\"\n",
    "    n_rows = x.shape[0]\n",
    "    assert n_rows == y.shape[0]\n",
    "    for start in range(0, n_rows, batch_size):\n",
    "        stop = start + batch_size\n",
    "        x_batch = x[start:stop].to(cuda)\n",
    "        y_batch = y[start:stop].to(cuda)\n",
    "        yield x_batch, y_batch\n",
    "\n",
    "def _tokenize(instances: list[str]):\n",
    "    \"\"\"Tokenize `instances` with the module-level `tokenizer` and return only the `input_ids`.\n",
    "\n",
    "    Every sequence is padded to, and truncated at, 512 tokens, and the result is requested as PyTorch tensors.\n",
    "    \"\"\"\n",
    "    encoded = tokenizer(\n",
    "        instances,\n",
    "        return_tensors='pt',\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "    )\n",
    "    return encoded.input_ids\n",
    "\n",
    "def _encode_labels(labels):\n",
    "    \"\"\"Binarize label collections with the module-level `mlb` and wrap the result in a `torch.Tensor`.\n",
    "\n",
    "    :labels: should be the `labels` column (a Series) of the DataFrame\n",
    "    \"\"\"\n",
    "    label_array = labels.to_numpy()\n",
    "    indicator_matrix = mlb.transform(label_array)\n",
    "    return torch.Tensor(indicator_matrix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8ad37bdc-842f-46d7-b929-d0db6eae45e0",
   "metadata": {
    "executionInfo": {
     "elapsed": 485,
     "status": "ok",
     "timestamp": 1689378605219,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "8ad37bdc-842f-46d7-b929-d0db6eae45e0",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Tokenize the training sentences once up front; batches are sliced from this result later.\n",
    "train_sentences = train['sentence'].tolist()\n",
    "x_train = _tokenize(train_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "58cb43bb-ccf8-4ecc-849e-be8a52c26585",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 117,
     "status": "ok",
     "timestamp": 1689378609654,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "58cb43bb-ccf8-4ecc-849e-be8a52c26585",
    "outputId": "955c99bb-2c80-4329-a48a-d9cec2b00b6b",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 1., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Encode the training labels; rows line up with `x_train` because both come from `train`.\n",
    "train_labels = train['labels']\n",
    "y_train = _encode_labels(train_labels)\n",
    "y_train"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3504844-b451-490e-8bad-c9d0f975999e",
   "metadata": {
    "id": "c3504844-b451-490e-8bad-c9d0f975999e"
   },
   "source": [
    "This cell contains the training loop. You may change the `NUM_EPOCHS` value to any integer you would like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce77f9f-76d2-4e4c-90b4-77106f0154cb",
   "metadata": {
    "id": "cce77f9f-76d2-4e4c-90b4-77106f0154cb",
    "tags": []
   },
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 3\n",
    "BATCH_SIZE = 10\n",
    "\n",
    "from statistics import mean\n",
    "\n",
    "from tqdm import tqdm\n",
    "from torch.optim import AdamW\n",
    "\n",
    "optim = AdamW(bert.parameters(), lr=2e-5, eps=1e-8)\n",
    "\n",
    "# Mini-batches per epoch (ceiling division). `_load_data` is a generator with no length,\n",
    "# so the total is passed explicitly to let the progress bar show completion and an ETA.\n",
    "num_batches = (x_train.shape[0] + BATCH_SIZE - 1) // BATCH_SIZE\n",
    "\n",
    "# Ensure the model is in training mode, even if it was switched to eval mode elsewhere.\n",
    "bert.train()\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    epoch_losses = []\n",
    "    batches = _load_data(x_train, y_train, batch_size=BATCH_SIZE)\n",
    "    for x, y in tqdm(batches, total=num_batches, desc=f'epoch {epoch + 1}'):\n",
    "        # Clear gradients left over from the previous step.\n",
    "        bert.zero_grad()\n",
    "        # The attention mask is 1 for real tokens and 0 for padding positions.\n",
    "        out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)\n",
    "        epoch_losses.append(out.loss.item())\n",
    "        out.loss.backward()\n",
    "        optim.step()\n",
    "    print(f\"epoch {epoch + 1} loss: {mean(epoch_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9af32bbd-2337-4d2d-bf71-0dacb40afb43",
   "metadata": {
    "id": "9af32bbd-2337-4d2d-bf71-0dacb40afb43"
   },
   "source": [
    "If the loss from the last iteration was not to your liking, do not re-run the previous cell. Uncomment the following cell and run it for however many additional epochs you would like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa5a01a0-eb3b-4564-bd4a-a43f819b9aeb",
   "metadata": {
    "id": "fa5a01a0-eb3b-4564-bd4a-a43f819b9aeb"
   },
   "outputs": [],
   "source": [
    "# NUM_EXTRA_EPOCHS = 1\n",
    "# for epoch in range(NUM_EXTRA_EPOCHS):\n",
    "#     epoch_losses = []\n",
    "#     for x, y in tqdm(_load_data(x_train, y_train, batch_size=10)):\n",
    "#         bert.zero_grad()\n",
    "#         out = bert(x, attention_mask=x.ne(tokenizer.pad_token_id).to(int), labels=y)\n",
    "#         epoch_losses.append(out.loss.item())\n",
    "#         out.loss.backward()\n",
    "#         optim.step()\n",
    "#     print(f\"epoch {epoch + 1} loss: {mean(epoch_losses)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d7311a4-6ac8-4397-8497-369865a7c924",
   "metadata": {
    "id": "9d7311a4-6ac8-4397-8497-369865a7c924"
   },
   "source": [
    "The next cells evaluate the performance after the additional fine-tuning. The performance scores on the example data will be high, as the model has already been trained on most of these instances."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9eff1660-da8b-4d9d-bdcd-53d01d5a05a4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "executionInfo": {
     "elapsed": 3145,
     "status": "ok",
     "timestamp": 1689378634778,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "9eff1660-da8b-4d9d-bdcd-53d01d5a05a4",
    "outputId": "23434cbc-dfd5-4737-cd0b-2d7dcf498278",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "  <div id=\"df-b7f64064-9425-4c68-b047-5cdc64fa5049\">\n",
       "    <div class=\"colab-df-container\">\n",
       "      <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>predicted</th>\n",
       "      <th>actual</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>(T1562.001)</td>\n",
       "      <td>(T1562.001, T1484.001)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>()</td>\n",
       "      <td>(T1106)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>95</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>97</th>\n",
       "      <td>(T1140)</td>\n",
       "      <td>(T1027, T1140)</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>98</th>\n",
       "      <td>()</td>\n",
       "      <td>()</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>99</th>\n",
       "      <td>()</td>\n",
       "      <td>(T1005)</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>100 rows × 2 columns</p>\n",
       "</div>\n",
       "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b7f64064-9425-4c68-b047-5cdc64fa5049')\"\n",
       "              title=\"Convert this dataframe to an interactive table.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "       width=\"24px\">\n",
       "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
       "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
       "  </svg>\n",
       "      </button>\n",
       "\n",
       "\n",
       "\n",
       "    <div id=\"df-fc09d5bd-7ab5-4c66-940f-bfd1cf957e7e\">\n",
       "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-fc09d5bd-7ab5-4c66-940f-bfd1cf957e7e')\"\n",
       "              title=\"Suggest charts.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "      </button>\n",
       "    </div>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "    background-color: #E8F0FE;\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: #1967D2;\n",
       "    height: 32px;\n",
       "    padding: 0 0 0 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: #E2EBFA;\n",
       "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: #174EA6;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "    background-color: #3B4455;\n",
       "    fill: #D2E3FC;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart:hover {\n",
       "    background-color: #434B5C;\n",
       "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "    fill: #FFFFFF;\n",
       "  }\n",
       "</style>\n",
       "\n",
       "    <script>\n",
       "      async function quickchart(key) {\n",
       "        const containerElement = document.querySelector('#' + key);\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      }\n",
       "    </script>\n",
       "\n",
       "      <script>\n",
       "\n",
       "function displayQuickchartButton(domScope) {\n",
       "  let quickchartButtonEl =\n",
       "    domScope.querySelector('#df-fc09d5bd-7ab5-4c66-940f-bfd1cf957e7e button.colab-df-quickchart');\n",
       "  quickchartButtonEl.style.display =\n",
       "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "}\n",
       "\n",
       "        displayQuickchartButton(document);\n",
       "      </script>\n",
       "      <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      flex-wrap:wrap;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "      <script>\n",
       "        const buttonEl =\n",
       "          document.querySelector('#df-b7f64064-9425-4c68-b047-5cdc64fa5049 button.colab-df-convert');\n",
       "        buttonEl.style.display =\n",
       "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "        async function convertToInteractive(key) {\n",
       "          const element = document.querySelector('#df-b7f64064-9425-4c68-b047-5cdc64fa5049');\n",
       "          const dataTable =\n",
       "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                     [key], {});\n",
       "          if (!dataTable) return;\n",
       "\n",
       "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "            + ' to learn more about interactive tables.';\n",
       "          element.innerHTML = '';\n",
       "          dataTable['output_type'] = 'display_data';\n",
       "          await google.colab.output.renderOutput(dataTable, element);\n",
       "          const docLink = document.createElement('div');\n",
       "          docLink.innerHTML = docLinkHtml;\n",
       "          element.appendChild(docLink);\n",
       "        }\n",
       "      </script>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "      predicted                  actual\n",
       "0   (T1562.001)  (T1562.001, T1484.001)\n",
       "1            ()                      ()\n",
       "2            ()                      ()\n",
       "3            ()                      ()\n",
       "4            ()                 (T1106)\n",
       "..          ...                     ...\n",
       "95           ()                      ()\n",
       "96           ()                      ()\n",
       "97      (T1140)          (T1027, T1140)\n",
       "98           ()                      ()\n",
       "99           ()                 (T1005)\n",
       "\n",
       "[100 rows x 2 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Put the model in evaluation mode before scoring the `test` split.\n",
    "bert.eval()\n",
    "\n",
    "x_test = _tokenize(test['sentence'].tolist())\n",
    "y_test = _encode_labels(test['labels'])\n",
    "\n",
    "batch_size = 20\n",
    "preds = []\n",
    "\n",
    "# Inference only: gradients are not needed, so do not track them.\n",
    "with torch.no_grad():\n",
    "    for start in range(0, x_test.shape[0], batch_size):\n",
    "        x = x_test[start : start + batch_size].to(cuda)\n",
    "        # 1 where the token is real, 0 where it is padding\n",
    "        pad_mask = x.ne(tokenizer.pad_token_id).to(int)\n",
    "        out = bert(x, attention_mask=pad_mask)\n",
    "        preds.append(out.logits.to('cpu'))\n",
    "\n",
    "# A label counts as predicted when its sigmoid score is above 0.5.\n",
    "binary_preds = torch.vstack(preds).sigmoid().gt(.5).to(int)\n",
    "\n",
    "# Decode the 0/1 indicator rows back into sets of labels so they can be compared.\n",
    "preds_series = pd.Series(mlb.inverse_transform(binary_preds)).apply(frozenset)\n",
    "actual_series = pd.Series(mlb.inverse_transform(y_test)).apply(frozenset)\n",
    "results = pd.DataFrame({'predicted': preds_series, 'actual': actual_series})\n",
    "\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "93acae94-ea14-4d49-b769-d778a3f6457f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 520
    },
    "executionInfo": {
     "elapsed": 236,
     "status": "ok",
     "timestamp": 1689378657099,
     "user": {
      "displayName": "tram",
      "userId": "11961082670110789134"
     },
     "user_tz": 240
    },
    "id": "93acae94-ea14-4d49-b769-d778a3f6457f",
    "outputId": "d671699a-7188-495b-ff46-589ff306d7f0",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "\n",
       "  <div id=\"df-2a34b3c2-936d-4f3d-84e8-e0af0ad73283\">\n",
       "    <div class=\"colab-df-container\">\n",
       "      <div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>P</th>\n",
       "      <th>R</th>\n",
       "      <th>F1</th>\n",
       "      <th>#</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>T1059.003</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1140</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1562.001</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1055</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1574.002</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1053.005</th>\n",
       "      <td>0.750000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.857143</td>\n",
       "      <td>3.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1027</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.666667</td>\n",
       "      <td>8.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1106</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>6.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1484.001</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>4.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1005</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1057</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1047</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>T1210</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>(macro)</th>\n",
       "      <td>0.596154</td>\n",
       "      <td>0.525641</td>\n",
       "      <td>0.540293</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>(micro)</th>\n",
       "      <td>0.941176</td>\n",
       "      <td>0.484848</td>\n",
       "      <td>0.640000</td>\n",
       "      <td>NaN</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>\n",
       "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-2a34b3c2-936d-4f3d-84e8-e0af0ad73283')\"\n",
       "              title=\"Convert this dataframe to an interactive table.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "       width=\"24px\">\n",
       "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
       "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
       "  </svg>\n",
       "      </button>\n",
       "\n",
       "\n",
       "\n",
       "    <div id=\"df-5f6255d8-bf48-44c0-bdc2-77ea25494849\">\n",
       "      <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-5f6255d8-bf48-44c0-bdc2-77ea25494849')\"\n",
       "              title=\"Suggest charts.\"\n",
       "              style=\"display:none;\">\n",
       "\n",
       "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
       "     width=\"24px\">\n",
       "    <g>\n",
       "        <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
       "    </g>\n",
       "</svg>\n",
       "      </button>\n",
       "    </div>\n",
       "\n",
       "<style>\n",
       "  .colab-df-quickchart {\n",
       "    background-color: #E8F0FE;\n",
       "    border: none;\n",
       "    border-radius: 50%;\n",
       "    cursor: pointer;\n",
       "    display: none;\n",
       "    fill: #1967D2;\n",
       "    height: 32px;\n",
       "    padding: 0 0 0 0;\n",
       "    width: 32px;\n",
       "  }\n",
       "\n",
       "  .colab-df-quickchart:hover {\n",
       "    background-color: #E2EBFA;\n",
       "    box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "    fill: #174EA6;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart {\n",
       "    background-color: #3B4455;\n",
       "    fill: #D2E3FC;\n",
       "  }\n",
       "\n",
       "  [theme=dark] .colab-df-quickchart:hover {\n",
       "    background-color: #434B5C;\n",
       "    box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "    filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "    fill: #FFFFFF;\n",
       "  }\n",
       "</style>\n",
       "\n",
       "    <script>\n",
       "      async function quickchart(key) {\n",
       "        const containerElement = document.querySelector('#' + key);\n",
       "        const charts = await google.colab.kernel.invokeFunction(\n",
       "            'suggestCharts', [key], {});\n",
       "      }\n",
       "    </script>\n",
       "\n",
       "      <script>\n",
       "\n",
       "function displayQuickchartButton(domScope) {\n",
       "  let quickchartButtonEl =\n",
       "    domScope.querySelector('#df-5f6255d8-bf48-44c0-bdc2-77ea25494849 button.colab-df-quickchart');\n",
       "  quickchartButtonEl.style.display =\n",
       "    google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "}\n",
       "\n",
       "        displayQuickchartButton(document);\n",
       "      </script>\n",
       "      <style>\n",
       "    .colab-df-container {\n",
       "      display:flex;\n",
       "      flex-wrap:wrap;\n",
       "      gap: 12px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert {\n",
       "      background-color: #E8F0FE;\n",
       "      border: none;\n",
       "      border-radius: 50%;\n",
       "      cursor: pointer;\n",
       "      display: none;\n",
       "      fill: #1967D2;\n",
       "      height: 32px;\n",
       "      padding: 0 0 0 0;\n",
       "      width: 32px;\n",
       "    }\n",
       "\n",
       "    .colab-df-convert:hover {\n",
       "      background-color: #E2EBFA;\n",
       "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
       "      fill: #174EA6;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert {\n",
       "      background-color: #3B4455;\n",
       "      fill: #D2E3FC;\n",
       "    }\n",
       "\n",
       "    [theme=dark] .colab-df-convert:hover {\n",
       "      background-color: #434B5C;\n",
       "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
       "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
       "      fill: #FFFFFF;\n",
       "    }\n",
       "  </style>\n",
       "\n",
       "      <script>\n",
       "        const buttonEl =\n",
       "          document.querySelector('#df-2a34b3c2-936d-4f3d-84e8-e0af0ad73283 button.colab-df-convert');\n",
       "        buttonEl.style.display =\n",
       "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
       "\n",
       "        async function convertToInteractive(key) {\n",
       "          const element = document.querySelector('#df-2a34b3c2-936d-4f3d-84e8-e0af0ad73283');\n",
       "          const dataTable =\n",
       "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
       "                                                     [key], {});\n",
       "          if (!dataTable) return;\n",
       "\n",
       "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
       "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
       "            + ' to learn more about interactive tables.';\n",
       "          element.innerHTML = '';\n",
       "          dataTable['output_type'] = 'display_data';\n",
       "          await google.colab.output.renderOutput(dataTable, element);\n",
       "          const docLink = document.createElement('div');\n",
       "          docLink.innerHTML = docLinkHtml;\n",
       "          element.appendChild(docLink);\n",
       "        }\n",
       "      </script>\n",
       "    </div>\n",
       "  </div>\n"
      ],
      "text/plain": [
       "                  P         R        F1    #\n",
       "T1059.003  1.000000  1.000000  1.000000  2.0\n",
       "T1140      1.000000  1.000000  1.000000  2.0\n",
       "T1562.001  1.000000  1.000000  1.000000  1.0\n",
       "T1055      1.000000  1.000000  1.000000  1.0\n",
       "T1574.002  1.000000  1.000000  1.000000  1.0\n",
       "T1053.005  0.750000  1.000000  0.857143  3.0\n",
       "T1027      1.000000  0.500000  0.666667  8.0\n",
       "T1106      1.000000  0.333333  0.500000  6.0\n",
       "T1484.001  0.000000  0.000000  0.000000  4.0\n",
       "T1005      0.000000  0.000000  0.000000  2.0\n",
       "T1057      0.000000  0.000000  0.000000  1.0\n",
       "T1047      0.000000  0.000000  0.000000  1.0\n",
       "T1210      0.000000  0.000000  0.000000  1.0\n",
       "(macro)    0.596154  0.525641  0.540293  NaN\n",
       "(micro)    0.941176  0.484848  0.640000  NaN"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _safe_ratio(numerator, denominator):\n",
    "    \"\"\"Return numerator / denominator, or 0.0 when the denominator is zero.\"\"\"\n",
    "    return numerator / denominator if denominator else 0.0\n",
    "\n",
    "\n",
    "# Per-label counts: set algebra on each row picks out the true positives, false\n",
    "# positives and false negatives; explode() flattens the sets to one label per row\n",
    "# so value_counts() can tally them.\n",
    "tp = results.apply((lambda r: r.predicted & r.actual), axis=1).explode().value_counts()\n",
    "fp = results.apply((lambda r: r.predicted - r.actual), axis=1).explode().value_counts()\n",
    "fn = results.apply((lambda r: r.actual - r.predicted), axis=1).explode().value_counts()\n",
    "counts = pd.concat({'tp': tp, 'fp': fp, 'fn': fn}, axis=1).fillna(0).astype(int)\n",
    "\n",
    "# number of rows whose actual label set contains each label\n",
    "support = actual_series.explode().value_counts().rename('#')\n",
    "\n",
    "# A label with a zero denominator gives 0/0 = NaN; score it as 0.\n",
    "p = counts.tp.div(counts.tp + counts.fp).fillna(0)\n",
    "r = counts.tp.div(counts.tp + counts.fn).fillna(0)\n",
    "f1 = (2 * p * r) / (p + r)\n",
    "\n",
    "scores = pd.concat({'P': p, 'R': r, 'F1': f1}, axis=1).fillna(0).sort_values(by='F1', ascending=False)\n",
    "\n",
    "# calculate macro scores (unweighted mean over the per-label rows)\n",
    "scores.loc['(macro)'] = scores.mean()\n",
    "\n",
    "# calculate micro scores from the pooled counts; a zero denominator scores 0,\n",
    "# matching the per-label rows, instead of producing NaN and a RuntimeWarning\n",
    "micro = counts.sum()\n",
    "mP = _safe_ratio(micro.tp, micro.tp + micro.fp)\n",
    "mR = _safe_ratio(micro.tp, micro.tp + micro.fn)\n",
    "scores.loc['(micro)', 'P'] = mP\n",
    "scores.loc['(micro)', 'R'] = mR\n",
    "scores.loc['(micro)', 'F1'] = _safe_ratio(2 * mP * mR, mP + mR)\n",
    "\n",
    "scores.join(support)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "widgets": {
   "application/vnd.jupyter.widget-state+json": {
    "04b96cb68fb6460294d0912c6dc2c4af": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "11fe0689ac4a4dca9cf2f88793bdceb6": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_44e810873d234087a8de6598e82b62c8",
      "max": 227845,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_a8a3deb0863a456ab9303dc4d4fdb63d",
      "value": 227845
     }
    },
    "29e5a977ca2e4c95bd106594066d365b": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_8146378b52074fc6bd8042aa07c855ac",
       "IPY_MODEL_3bbfc7ca557d438193cd246d85e93ceb",
       "IPY_MODEL_dd9b9569b2394e0fb563838c553961a0"
      ],
      "layout": "IPY_MODEL_fbdb96e08dc94bfc98bb3da709521bb1"
     }
    },
    "3bbfc7ca557d438193cd246d85e93ceb": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "FloatProgressModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "FloatProgressModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "ProgressView",
      "bar_style": "success",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_93c6d77bfe614b258242919b33c4a6f4",
      "max": 385,
      "min": 0,
      "orientation": "horizontal",
      "style": "IPY_MODEL_912dfc4ef2b94a58b924e21dbe7222df",
      "value": 385
     }
    },
    "44e810873d234087a8de6598e82b62c8": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "5f89cc094cac4f77b3058819bce39af3": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "66eb7fdf62344f2fb31bf787d7a43d39": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HBoxModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HBoxModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HBoxView",
      "box_style": "",
      "children": [
       "IPY_MODEL_be7219f5baa0404c957c33f57b90af43",
       "IPY_MODEL_11fe0689ac4a4dca9cf2f88793bdceb6",
       "IPY_MODEL_a031373b5ae04224996c0d377603028a"
      ],
      "layout": "IPY_MODEL_fbb2fdd3537e4ec8a2d1c7304414f934"
     }
    },
    "8146378b52074fc6bd8042aa07c855ac": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_d07290d0e08a4923b7175e000576a767",
      "placeholder": "​",
      "style": "IPY_MODEL_bdd93e271bbd45fdae397e9918a8a8b7",
      "value": "Downloading (…)lve/main/config.json: 100%"
     }
    },
    "912dfc4ef2b94a58b924e21dbe7222df": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "93c6d77bfe614b258242919b33c4a6f4": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "97250b845e38433c94c98a040c58fcc2": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "a031373b5ae04224996c0d377603028a": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_5f89cc094cac4f77b3058819bce39af3",
      "placeholder": "​",
      "style": "IPY_MODEL_97250b845e38433c94c98a040c58fcc2",
      "value": " 228k/228k [00:00&lt;00:00, 3.54MB/s]"
     }
    },
    "a8a3deb0863a456ab9303dc4d4fdb63d": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "ProgressStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "ProgressStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "bar_color": null,
      "description_width": ""
     }
    },
    "bdd93e271bbd45fdae397e9918a8a8b7": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "be7219f5baa0404c957c33f57b90af43": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_f1ed4a086a9d46deac52a934c65bd46f",
      "placeholder": "​",
      "style": "IPY_MODEL_fbafe2556e6e4b97840c34655b826810",
      "value": "Downloading (…)solve/main/vocab.txt: 100%"
     }
    },
    "d07290d0e08a4923b7175e000576a767": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "d197eece69da4eb29203ce2498369c82": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "dd9b9569b2394e0fb563838c553961a0": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "HTMLModel",
     "state": {
      "_dom_classes": [],
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "HTMLModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/controls",
      "_view_module_version": "1.5.0",
      "_view_name": "HTMLView",
      "description": "",
      "description_tooltip": null,
      "layout": "IPY_MODEL_04b96cb68fb6460294d0912c6dc2c4af",
      "placeholder": "​",
      "style": "IPY_MODEL_d197eece69da4eb29203ce2498369c82",
      "value": " 385/385 [00:00&lt;00:00, 12.2kB/s]"
     }
    },
    "f1ed4a086a9d46deac52a934c65bd46f": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fbafe2556e6e4b97840c34655b826810": {
     "model_module": "@jupyter-widgets/controls",
     "model_module_version": "1.5.0",
     "model_name": "DescriptionStyleModel",
     "state": {
      "_model_module": "@jupyter-widgets/controls",
      "_model_module_version": "1.5.0",
      "_model_name": "DescriptionStyleModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "StyleView",
      "description_width": ""
     }
    },
    "fbb2fdd3537e4ec8a2d1c7304414f934": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    },
    "fbdb96e08dc94bfc98bb3da709521bb1": {
     "model_module": "@jupyter-widgets/base",
     "model_module_version": "1.2.0",
     "model_name": "LayoutModel",
     "state": {
      "_model_module": "@jupyter-widgets/base",
      "_model_module_version": "1.2.0",
      "_model_name": "LayoutModel",
      "_view_count": null,
      "_view_module": "@jupyter-widgets/base",
      "_view_module_version": "1.2.0",
      "_view_name": "LayoutView",
      "align_content": null,
      "align_items": null,
      "align_self": null,
      "border": null,
      "bottom": null,
      "display": null,
      "flex": null,
      "flex_flow": null,
      "grid_area": null,
      "grid_auto_columns": null,
      "grid_auto_flow": null,
      "grid_auto_rows": null,
      "grid_column": null,
      "grid_gap": null,
      "grid_row": null,
      "grid_template_areas": null,
      "grid_template_columns": null,
      "grid_template_rows": null,
      "height": null,
      "justify_content": null,
      "justify_items": null,
      "left": null,
      "margin": null,
      "max_height": null,
      "max_width": null,
      "min_height": null,
      "min_width": null,
      "object_fit": null,
      "object_position": null,
      "order": null,
      "overflow": null,
      "overflow_x": null,
      "overflow_y": null,
      "padding": null,
      "right": null,
      "top": null,
      "visibility": null,
      "width": null
     }
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
